{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "4c8a1072",
   "metadata": {},
   "source": [
    "### P047 决策树 - 训练决策树分类模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "56b99adb",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "180ac863",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.datasets import make_moons\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.tree import DecisionTreeClassifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "ebf8e749",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed NumPy's global RNG, then draw the two-moons toy dataset.\n",
    "np.random.seed(42)\n",
    "raw_data = make_moons(n_samples=2000, noise=0.25, random_state=42)\n",
    "# make_moons returns a (features, labels) pair; unpack it into named arrays.\n",
    "data, target = raw_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "ed74518c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((2000, 2), (2000,))"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: feature-matrix and label-vector shapes.\n",
    "(data.shape, target.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "b497c4f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pass random_state explicitly: without it the split draws from NumPy's global RNG,\n",
    "# so re-running this cell on its own would produce a different train/test split.\n",
    "x_train, x_test, y_train, y_test = train_test_split(data, target, random_state=42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "a88a27d1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DecisionTreeClassifier(random_state=42)"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Baseline: unconstrained tree with default hyper-parameters.\n",
    "# random_state fixes the per-split feature permutation, so ties between equally\n",
    "# good splits are broken the same way on every run.\n",
    "classifer = DecisionTreeClassifier(random_state=42)\n",
    "classifer.fit(x_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "c729001e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.902"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Mean accuracy on the held-out test split.\n",
    "classifer.score(X=x_test, y=y_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "57cdfebf",
   "metadata": {},
   "source": [
    "### P048 决策树 - max_depth 树的最大深度"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "86877596",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DecisionTreeClassifier(max_depth=6, random_state=42)"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Cap the tree depth at 6; random_state keeps split tie-breaking reproducible.\n",
    "classifer = DecisionTreeClassifier(max_depth=6, random_state=42)\n",
    "classifer.fit(x_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "63258668",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.928"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Mean accuracy of the depth-limited tree on the held-out test split.\n",
    "classifer.score(X=x_test, y=y_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "16d2dd81",
   "metadata": {},
   "source": [
    "### P049 决策树 - min_samples_leaf 叶节点所需的最小样本数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "836564d0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DecisionTreeClassifier(max_depth=6, min_samples_leaf=6, random_state=42)"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Depth cap of 6 plus at least 6 training samples per leaf;\n",
    "# random_state keeps split tie-breaking reproducible.\n",
    "classifer = DecisionTreeClassifier(max_depth=6, min_samples_leaf=6, random_state=42)\n",
    "classifer.fit(x_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "516b0329",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.93"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Mean accuracy of the depth- and leaf-constrained tree on the held-out test split.\n",
    "classifer.score(X=x_test, y=y_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "796166ce",
   "metadata": {},
   "source": [
    "### P050 决策树 - 使用网格搜索获得最优的模型参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "480a7d6f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import GridSearchCV"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "027c99a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Search space: max_depth 1-9 crossed with min_samples_leaf 1-19 (171 combinations).\n",
    "params = dict(\n",
    "    max_depth=np.arange(1, 10),\n",
    "    min_samples_leaf=np.arange(1, 20),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "a4605609",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start from a fresh, seeded tree instead of reusing `classifer`: the grid sets\n",
    "# max_depth and min_samples_leaf itself, so the search should not depend on\n",
    "# whatever configuration the earlier cells left in that variable.\n",
    "grid_search = GridSearchCV(\n",
    "    DecisionTreeClassifier(random_state=42),\n",
    "    param_grid=params,\n",
    "    scoring=\"accuracy\",\n",
    "    cv=5,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "2441baf5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "GridSearchCV(cv=5,\n",
       "             estimator=DecisionTreeClassifier(max_depth=6, min_samples_leaf=6),\n",
       "             param_grid={'max_depth': array([1, 2, 3, 4, 5, 6, 7, 8, 9]),\n",
       "                         'min_samples_leaf': array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,\n",
       "       18, 19])},\n",
       "             scoring='accuracy')"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 5-fold cross-validated search over `params`, using only the training split.\n",
    "grid_search.fit(X=x_train, y=y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "43cb7644",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'max_depth': 6, 'min_samples_leaf': 6}"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Hyper-parameter combination with the highest mean cross-validated accuracy.\n",
    "grid_search.best_params_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e97eeaa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the tuned model on the held-out test split. GridSearchCV refits the\n",
    "# best parameter combination on the whole training split (refit=True by default),\n",
    "# so best_estimator_ is ready to score; best_score_ is its mean CV accuracy.\n",
    "grid_search.best_score_, grid_search.best_estimator_.score(x_test, y_test)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
